---
title: CycleGAN training loop
keywords: fastai
sidebar: home_sidebar
summary: "Defines the loss and training loop functions/classes for CycleGAN."
description: "Defines the loss and training loop functions/classes for CycleGAN."
---
Let's start by writing the loss function for the CycleGAN model. The main loss is used to train the generators. It has three parts: an identity loss, an adversarial loss, and a cycle-consistency loss.
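As a sketch of how these three parts combine (assuming LSGAN-style adversarial targets and the usual lambda weightings; the function name and argument layout here are illustrative, not the library's actual API):

```python
import torch
import torch.nn.functional as F

def generator_loss_sketch(real_A, real_B,            # inputs from each domain
                          idt_A, idt_B,              # identity outputs: should match real_A/real_B
                          cyc_A, cyc_B,              # cycle reconstructions of real_A/real_B
                          pred_fake_A, pred_fake_B,  # discriminator scores on generated images
                          l_A=10.0, l_B=10.0, l_idt=0.5):
    # Identity loss: feeding a generator an image from its own target domain
    # should leave the image (nearly) unchanged
    loss_idt = l_idt * (l_A * F.l1_loss(idt_A, real_A) +
                        l_B * F.l1_loss(idt_B, real_B))
    # Adversarial (LSGAN) loss: push the discriminators' scores on generated
    # images toward the "real" target of 1
    loss_adv = (F.mse_loss(pred_fake_A, torch.ones_like(pred_fake_A)) +
                F.mse_loss(pred_fake_B, torch.ones_like(pred_fake_B)))
    # Cycle-consistency loss: translating A -> B -> A should reconstruct A
    loss_cyc = l_A * F.l1_loss(cyc_A, real_A) + l_B * F.l1_loss(cyc_B, real_B)
    return loss_idt + loss_adv + loss_cyc
```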
Let's now write the main callback to train a CycleGAN model.
Fastai's callback system is very flexible, allowing us to adjust the traditional training loop in almost any way we can conceive. Let's use it for GAN training.
We have the `_set_trainable` function, which is called with arguments specifying which networks should be put in training mode and which should be frozen.
When training starts (`before_train`), we define separate optimizers: `self.opt_G` for the generators and `self.opt_D` for the discriminators. Then we put the generators in training mode (with `_set_trainable`).
Before passing the batch into the model (`before_batch`), we have to adjust it: the domain-B image was kept as the target, but it also needs to be passed into the model. We also set the inputs for the loss function.
In `after_batch`, we calculate the discriminator losses, backpropagate, and update the weights of both discriminators. The main training loop trains the generators.
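The flow above can be sketched framework-free (the class and method bodies here are illustrative stand-ins for the callback's responsibilities, not fastai's or this library's actual API):

```python
class CycleGANTrainerSketch:
    """Illustrative stand-in for the training callback; all names are assumptions."""
    def __init__(self, generators, discriminators, make_opt):
        self.generators, self.discriminators = generators, discriminators
        self.make_opt = make_opt  # factory: list of nets -> optimizer stand-in

    def _set_trainable(self, gen, disc):
        # Put one group in training mode and freeze the other
        for net in self.generators: net['trainable'] = gen
        for net in self.discriminators: net['trainable'] = disc

    def before_train(self):
        # Separate optimizers for the two groups; generators train first
        self.opt_G = self.make_opt(self.generators)
        self.opt_D = self.make_opt(self.discriminators)
        self._set_trainable(gen=True, disc=False)

    def before_batch(self, xa, xb):
        # The domain-B image is the stored target, but the model needs it too
        return (xa, xb), xb  # (model input, loss target)

    def after_batch(self):
        # Discriminator step: freeze generators, update both discriminators,
        # then restore generator training for the main loop's generator step
        self._set_trainable(gen=False, disc=True)
        # ...compute discriminator losses, backprop, and step self.opt_D here...
        self._set_trainable(gen=True, disc=False)
```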
The original CycleGAN paper trained with a period of constant learning rate followed by a period of linearly decaying learning rate. Let's make a scheduler that implements this (with other annealing curves as well). Fastai already comes with many types of hyperparameter schedules, and new ones can be created by combining existing ones. Let's see how to do this:
```python
import torch
import matplotlib.pyplot as plt

p = torch.linspace(0., 1, 200)
plt.plot(p, [combined_flat_anneal(0.5, 1, 1e-2, curve_type='linear')(o) for o in p], label='linear annealing')
plt.plot(p, [combined_flat_anneal(0.5, 1, 1e-2, curve_type='cosine')(o) for o in p], label='cosine annealing')
plt.plot(p, [combined_flat_anneal(0.5, 1, 1e-2, curve_type='exponential')(o) for o in p], label='exponential annealing')
plt.legend()
plt.title('Constant+annealing LR schedules')
```
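Under the hood, a schedule like this maps training progress `p` in `[0, 1]` to a hyperparameter value: flat until `pct_start`, then annealed toward the end value. A minimal pure-Python sketch of the math (the real `combined_flat_anneal` is built from fastai's schedule combinators; the name `flat_then_anneal` here is hypothetical):

```python
import math

def flat_then_anneal(pct_start, lr_start, lr_end, curve_type='linear'):
    """Return a schedule f(p) for p in [0, 1]: constant lr_start for the first
    pct_start fraction of training, then anneal to lr_end with the chosen curve."""
    def sched(p):
        if p <= pct_start:
            return lr_start
        t = (p - pct_start) / (1 - pct_start)  # 0..1 over the annealing phase
        if curve_type == 'linear':
            return lr_start + t * (lr_end - lr_start)
        if curve_type == 'cosine':
            return lr_end + (lr_start - lr_end) * (1 + math.cos(math.pi * t)) / 2
        if curve_type == 'exponential':
            return lr_start * (lr_end / lr_start) ** t
        raise ValueError(f"unknown curve_type: {curve_type}")
    return sched
```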
Now that we have the learning rate schedule, we can write a quick training function that can be added as a method to `Learner` using the `@patch` decorator. The function is inspired by this code.
```python
from fastai.test_utils import *

learn = synth_learner()
learn.fit_flat_lin(n_epochs=2, n_epochs_decay=2)
learn.recorder.plot_sched()
```
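The resulting schedule can be described per epoch: `n_epochs` epochs at a constant learning rate, then `n_epochs_decay` epochs decaying linearly to zero. A small illustrative helper showing that arithmetic (`flat_lin_lr` is a hypothetical name, not the patched method itself):

```python
def flat_lin_lr(epoch, n_epochs, n_epochs_decay, lr):
    """LR at a given 0-indexed epoch under a flat-then-linear-decay schedule."""
    if epoch < n_epochs:
        return lr                                 # flat phase
    t = (epoch - n_epochs + 1) / n_epochs_decay   # 1/N, 2/N, ..., 1 over decay
    return lr * (1 - t)                           # linear decay to zero
```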
Below, we define a method for initializing a `Learner` with the CycleGAN model and training callback.
```python
horse2zebra = untar_data('https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/horse2zebra.zip')

folders = horse2zebra.ls().sorted()
trainA_path = folders[2]
trainB_path = folders[3]
testA_path = folders[0]
testB_path = folders[1]

dls = get_dls(trainA_path, trainB_path, num_A=100)
cycle_gan = CycleGAN(3, 3, 64)
learn = cycle_learner(dls, cycle_gan, show_img_interval=1)
learn.show_training_loop()
test_eq(type(learn), Learner)
learn.lr_find()
learn.fit_flat_lin(5, 5, 2e-4)
learn.recorder.plot_loss(with_valid=False)
```